{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import torch\n",
    "torch.cuda.is_available()\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(device)\n",
    "\n",
    "from sklearn.metrics import accuracy_score, precision_recall_fscore_support\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "\n",
    "from transformers import AutoTokenizer, Trainer, TrainingArguments, AutoModelForSequenceClassification, AdamW"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define a class for the AMP data that will correctly format the sequence information\n",
    "# for fine-tuning with huggingface API\n",
    "\n",
    "class amp_data():\n",
    "    def __init__(self, df, tokenizer_name='Rostlab/prot_bert_bfd', max_len=200):\n",
    "        \n",
    "        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, do_lower_case=False)\n",
    "        self.max_len = max_len\n",
    "        \n",
    "        self.seqs, self.labels = self.get_seqs_labels()\n",
    "        \n",
    "    def get_seqs_labels(self):        \n",
    "        # isolate the amino acid sequences and their respective AMP labels\n",
    "        seqs = list(df['aa_seq'])\n",
    "        labels = list(df['AMP'].astype(int))\n",
    "        \n",
    "#         assert len(seqs) == len(labels)\n",
    "        return seqs, labels\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.labels)\n",
    "    \n",
    "    def __getitem__(self, idx):\n",
    "        seq = \" \".join(\"\".join(self.seqs[idx].split()))\n",
    "        seq_ids = self.tokenizer(seq, truncation=True, padding='max_length', max_length=self.max_len)\n",
    "        \n",
    "        sample = {key: torch.tensor(val) for key, val in seq_ids.items()}\n",
    "        sample['labels'] = torch.tensor(self.labels[idx])\n",
    "\n",
    "        return sample"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "                                                            aa_seq  aa_len  \\\n",
      "AP02151          YEALVTSILGKLTGLWHNDSVDFMGHICYFRRRPKIRRFKLYHEGK...      95   \n",
      "AP01951                                          FLPLVLGALSGILPKIL      17   \n",
      "AP00972                                        FLSLIPHAINAVGVHAKHF      19   \n",
      "AP01261                                           IIEKLVNTALGLLSGL      16   \n",
      "AP01298                                       GLFTLIKCAYQLIAPTVACN      20   \n",
      "AP01802                                     RPWAGNGSVHRYTVLSPRLKTQ      22   \n",
      "UniRef50_Q9UTR1                                SKENSYVEKLLYKQRFYAS      19   \n",
      "\n",
      "                   AMP  \n",
      "AP02151           True  \n",
      "AP01951           True  \n",
      "AP00972           True  \n",
      "AP01261           True  \n",
      "AP01298           True  \n",
      "AP01802           True  \n",
      "UniRef50_Q9UTR1  False  \n"
     ]
    }
   ],
   "source": [
    "# read in the train dataset\n",
    "# create an amp_data class of the dataset\n",
    "\n",
    "df = pd.read_csv('/home/hansol/amp/all_veltri.csv', index_col = 0)\n",
    "df = df.sample(frac=1, random_state = 0)\n",
    "print(df.head(7))\n",
    "\n",
    "train_dataset = amp_data(df)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the necessary metrics for performance evaluation\n",
    "\n",
    "def compute_metrics(pred):\n",
    "    labels = pred.label_ids\n",
    "    preds = pred.predictions.argmax(-1)\n",
    "    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')\n",
    "    acc = accuracy_score(labels, preds)\n",
    "#     conf = confusion_matrix(labels, preds)\n",
    "    return {\n",
    "        'accuracy': acc,\n",
    "        'f1': f1,\n",
    "        'precision': precision,\n",
    "        'recall': recall,\n",
    "#         'confusion matrix': conf\n",
    "    }"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define the initializing function for Trainer in huggingface\n",
    "\n",
    "def model_init():\n",
    "    return AutoModelForSequenceClassification.from_pretrained('Rostlab/prot_bert_bfd')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "PyTorch: setting up devices\n",
      "The default value for the training argument `--report_to` will change in v5 (from all installed integrations to none). In v5, you will need to use `--report_to all` to get the same behavior as now. You should start updating your code and make this info disappear :-).\n",
      "loading configuration file https://huggingface.co/Rostlab/prot_bert_bfd/resolve/main/config.json from cache at /home/hansol/.cache/huggingface/transformers/67f460acc7e7e147ff828e909ffe419d00d66ce679c682bc4ab715c107bcbe41.baf557855a8618d0ddfb6c23bfd135bfc38ccf8c3fb099b8df45eb110ccf05e9\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"Rostlab/prot_bert_bfd\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.0,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.0,\n",
      "  \"hidden_size\": 1024,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 4096,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 40000,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 16,\n",
      "  \"num_hidden_layers\": 30,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.19.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30\n",
      "}\n",
      "\n",
      "loading weights file https://huggingface.co/Rostlab/prot_bert_bfd/resolve/main/pytorch_model.bin from cache at /home/hansol/.cache/huggingface/transformers/0a05878f9e3a0d39834dc6f21b88471696d7453a07bac7246152a6ef307c9af4.c5b9869da882baaf70e8e70cf32d81500803511e3220e24457115a03445fa65f\n",
      "Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_bert_bfd and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "Using amp half precision backend\n",
      "loading configuration file https://huggingface.co/Rostlab/prot_bert_bfd/resolve/main/config.json from cache at /home/hansol/.cache/huggingface/transformers/67f460acc7e7e147ff828e909ffe419d00d66ce679c682bc4ab715c107bcbe41.baf557855a8618d0ddfb6c23bfd135bfc38ccf8c3fb099b8df45eb110ccf05e9\n",
      "Model config BertConfig {\n",
      "  \"_name_or_path\": \"Rostlab/prot_bert_bfd\",\n",
      "  \"architectures\": [\n",
      "    \"BertForMaskedLM\"\n",
      "  ],\n",
      "  \"attention_probs_dropout_prob\": 0.0,\n",
      "  \"classifier_dropout\": null,\n",
      "  \"hidden_act\": \"gelu\",\n",
      "  \"hidden_dropout_prob\": 0.0,\n",
      "  \"hidden_size\": 1024,\n",
      "  \"initializer_range\": 0.02,\n",
      "  \"intermediate_size\": 4096,\n",
      "  \"layer_norm_eps\": 1e-12,\n",
      "  \"max_position_embeddings\": 40000,\n",
      "  \"model_type\": \"bert\",\n",
      "  \"num_attention_heads\": 16,\n",
      "  \"num_hidden_layers\": 30,\n",
      "  \"pad_token_id\": 0,\n",
      "  \"position_embedding_type\": \"absolute\",\n",
      "  \"transformers_version\": \"4.19.2\",\n",
      "  \"type_vocab_size\": 2,\n",
      "  \"use_cache\": true,\n",
      "  \"vocab_size\": 30\n",
      "}\n",
      "\n",
      "loading weights file https://huggingface.co/Rostlab/prot_bert_bfd/resolve/main/pytorch_model.bin from cache at /home/hansol/.cache/huggingface/transformers/0a05878f9e3a0d39834dc6f21b88471696d7453a07bac7246152a6ef307c9af4.c5b9869da882baaf70e8e70cf32d81500803511e3220e24457115a03445fa65f\n",
      "Some weights of the model checkpoint at Rostlab/prot_bert_bfd were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.weight']\n",
      "- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at Rostlab/prot_bert_bfd and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n",
      "***** Running training *****\n",
      "  Num examples = 3556\n",
      "  Num Epochs = 15\n",
      "  Instantaneous batch size per device = 1\n",
      "  Total train batch size (w. parallel, distributed & accumulation) = 448\n",
      "  Gradient Accumulation steps = 64\n",
      "  Total optimization steps = 105\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='105' max='105' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [105/105 5:59:35, Epoch 14/15]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.707100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "Training completed. Do not forget to share your model on huggingface.co/models =)\n",
      "\n",
      "\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=105, training_loss=0.6926387627919515, metrics={'train_runtime': 21761.5943, 'train_samples_per_second': 2.451, 'train_steps_per_second': 0.005, 'total_flos': 2.4064232304672e+16, 'train_loss': 0.6926387627919515, 'epoch': 14.88})"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# training on entire data\n",
    "# no evaluation/validation\n",
    "\n",
    "training_args = TrainingArguments(\n",
    "    output_dir='./results',          \n",
    "    num_train_epochs=15,              \n",
    "    learning_rate=5e-5,              \n",
    "    per_device_train_batch_size=1,   \n",
    "    warmup_steps=0,               \n",
    "    weight_decay=0.1,               \n",
    "    logging_dir='./logs',            \n",
    "    logging_steps=100,               \n",
    "    do_train=True,                   \n",
    "    do_eval=True,                   \n",
    "    evaluation_strategy=\"no\",    \n",
    "    save_strategy='no',\n",
    "    gradient_accumulation_steps=64,  \n",
    "    fp16=True,                       \n",
    "    fp16_opt_level=\"O2\",             \n",
    "    run_name=\"AMP-BERT\",             \n",
    "    seed=0,                          \n",
    "    load_best_model_at_end = True\n",
    ")\n",
    "\n",
    "trainer = Trainer(\n",
    "    model_init=model_init,                \n",
    "    args=training_args,                   \n",
    "    train_dataset=train_dataset,          \n",
    "    compute_metrics = compute_metrics,    \n",
    ")\n",
    "\n",
    "trainer.train()\n",
    "# trainer.save_model('/home/hansol/amp/model/')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "***** Running Prediction *****\n",
      "  Num examples = 3556\n",
      "  Batch size = 56\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='64' max='64' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [64/64 02:10]\n",
       "    </div>\n",
       "    "
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "{'test_loss': 0.3816959261894226,\n",
       " 'test_accuracy': 0.9569741282339708,\n",
       " 'test_f1': 0.9561478933791918,\n",
       " 'test_precision': 0.9748684979544127,\n",
       " 'test_recall': 0.9381327334083239,\n",
       " 'test_runtime': 132.835,\n",
       " 'test_samples_per_second': 26.77,\n",
       " 'test_steps_per_second': 0.482}"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# performance metrics on the training data itself\n",
    "\n",
    "predictions, label_ids, metrics = trainer.predict(train_dataset)\n",
    "metrics"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
